# Copyright 2020 ByteDance Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf

from neurst.criterions import register_criterion
from neurst.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropy
from neurst.models.model_utils import input_length_to_nonpadding
from neurst.utils.flags_core import Flag


@register_criterion
class LabelSmoothedCrossEntropyWithKd(LabelSmoothedCrossEntropy):
    """ Label-smoothed cross entropy combined with a hidden-state distillation loss.

    On top of the parent criterion's outputs, this criterion computes the squared
    L2 distance between `model_out["student_hidden_states"]` and
    `model_out["teacher_hidden_states"]` at every non-padded source position.
    The teacher states are wrapped in `tf.stop_gradient`, so only the student
    side receives gradients from the distillation term. The training loss is
    `(1 - kd_weight) * nll + kd_weight * kd`.
    """

    def __init__(self, args):
        """ Initializes the cross entropy with label smoothing and the KD loss.

        Args:
            args: A dict of full parameters. Besides the parent's parameters it
                must contain "kd_weight", a float in [0, 1).
        """
        super(LabelSmoothedCrossEntropyWithKd, self).__init__(args)
        self._kd_weight = args["kd_weight"]
        # With kd_weight == 1 the NLL term would vanish from `reduce_loss`.
        assert 0 <= self._kd_weight < 1., (
            "kd_weight must be in [0, 1), got {}".format(self._kd_weight))

    @staticmethod
    def class_or_method_args():
        """ Returns a list of args for flag definition (parent flags + `kd_weight`). """
        flags = super(LabelSmoothedCrossEntropyWithKd,
                      LabelSmoothedCrossEntropyWithKd).class_or_method_args()
        flags.append(Flag("kd_weight", dtype=Flag.TYPE.FLOAT,
                          default=0.1, help="The weight for KD loss."))
        return flags

    def reduce_loss(self, model_inp, model_out):
        """ Reduces loss tensor for training according to the model inputs
            and outputs.

        The NLL part is averaged over target tokens (sum of `nll_sum` divided by
        the sum of `n_tokens`) and the KD part over non-padded source tokens. The
        two are mixed as `nll * (1 - kd_weight) + kd * kd_weight`.

        Returns: A float tensor.
        """
        nll_sum, n_samples, n_tokens, kd_loss_sum, n_src_tokens = self(model_inp, model_out)
        nll = tf.reduce_sum(nll_sum) / tf.reduce_sum(n_tokens)
        kd = tf.reduce_sum(kd_loss_sum) / tf.reduce_sum(n_src_tokens)
        return nll * (1. - self._kd_weight) + kd * self._kd_weight

    def reduce_metrics(self, eval_res_list) -> dict:
        """ Reduces the metrics according to a list of returned value from `eval`.

        Args:
            eval_res_list: A list of tuples of numpy.ndarray generated by `self.__call__`
                and model.__call__, each being
                (nll_sum, n_samples, n_tokens, kd_loss_sum, n_src_tokens).

        Returns:
            A dict of reduced metrics for evaluation:
                "NLL": total NLL divided by the number of samples.
                "PPL": 2 ** (total NLL / number of target tokens).
                "KD_LOSS_per_sample": total KD loss divided by the number of samples.
                "KD_PPL": 2 ** (total KD loss / number of source tokens). The KD
                    loss is a squared distance rather than a log-likelihood, so
                    this only mirrors the PPL formula; it is not a probabilistic
                    perplexity.
        """
        nll_sum, nll_samples, nll_tokens, kd_loss_sum, n_src_tokens = 0., 0., 0., 0., 0.
        for _nll_sum, _nll_samples, _nll_tokens, _kd_loss_sum, _n_src_tokens in eval_res_list:
            nll_sum += tf.reduce_sum(_nll_sum)
            nll_samples += tf.reduce_sum(_nll_samples)
            nll_tokens += tf.reduce_sum(_nll_tokens)
            kd_loss_sum += tf.reduce_sum(_kd_loss_sum)
            n_src_tokens += tf.reduce_sum(_n_src_tokens)
        nll = nll_sum / nll_samples
        ppl = 2. ** (nll_sum / nll_tokens)
        return {"NLL": nll, "PPL": ppl,
                "KD_LOSS_per_sample": kd_loss_sum / nll_samples,
                "KD_PPL": 2. ** (kd_loss_sum / n_src_tokens)}

    def __call__(self, model_inp, model_out):
        """ Calculates the label-smoothed NLL and the hidden-state KD loss.

        Args:
            model_inp: A dict containing the model inputs. Must contain either
                "src_padding" (1 at padded positions) or both "src_length" and
                "src", which are used to build the source non-padding mask.
            model_out: A dict containing the logits consumed by the parent
                criterion, plus "teacher_hidden_states" and
                "student_hidden_states". The hidden states are assumed to have
                shape [batch, src_len, hidden] (the last axis is reduced).

        Returns:
            A tuple (nll_sum, num_of_samples, num_of_tokens, kd_loss_sum,
            num_of_src_tokens) with shape:
            nll_sum: [batch_size, ]
            num_of_samples: [1, ],
            num_of_tokens: [batch_size, ]
            kd_loss_sum: [batch_size, ], the masked sum over source positions of
                the squared L2 distance between student and teacher states.
            num_of_src_tokens: [batch_size, ], the number of non-padded
                source positions.
        """
        nll_sum, n_samples, n_tokens = super(LabelSmoothedCrossEntropyWithKd,
                                             self).__call__(model_inp, model_out)
        # Only the student is trained by the distillation term.
        teacher = tf.stop_gradient(model_out["teacher_hidden_states"])
        student = model_out["student_hidden_states"]

        with tf.name_scope("kd_loss"):
            # Cast first so that half-precision hidden states do not overflow
            # when squared and summed over the hidden dimension.
            diff = tf.cast(student, tf.float32) - tf.cast(teacher, tf.float32)
            # [batch, src_len]: squared L2 distance per position. This is written
            # as a sum of squares instead of square(norm(...)). The value is the
            # same, but the gradient of tf.norm is NaN wherever the distance is
            # exactly zero (sqrt at 0), and that NaN would survive the padding
            # mask below (0 * NaN = NaN).
            kd_loss = tf.reduce_sum(tf.math.square(diff), axis=-1)
            if "src_padding" in model_inp:
                src_weights = tf.cast(1 - model_inp["src_padding"], tf.float32)
            else:
                src_weights = input_length_to_nonpadding(model_inp["src_length"],
                                                         tf.shape(model_inp["src"])[1], tf.float32)
            kd_loss_sum = tf.reduce_sum(kd_loss * src_weights, axis=1)
            n_src_tokens = tf.reduce_sum(src_weights, axis=1)
        return nll_sum, n_samples, n_tokens, kd_loss_sum, n_src_tokens
